import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from sklearn.neighbors import KernelDensity
from sklearn.metrics import roc_auc_score, average_precision_score
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
from PIL import Image
import numpy as np
import logging
import itertools
import csv

# Enable cuDNN benchmark mode so cuDNN can auto-tune the convolution algorithms it uses.
torch.backends.cudnn.benchmark = True

# Logging setup: records at INFO level and above are written to 'training.log'.
logging.basicConfig(filename='training.log', level=logging.INFO)


class CustomImageDataset(Dataset):
    """Image dataset read from a ``root/<color_folder>/<file>`` directory layout.

    Every entry of ``roots`` is treated as one class: the label of a sample is
    the index of its root inside ``roots``.

    Args:
        roots: Iterable of directories, one per class.
        transform: Optional callable applied to the loaded RGB image.
        color_folders: Names of the sub-folders of each root to read. A single
            string is treated as one folder name. ``None`` (the default) reads
            every sub-directory of each root, in sorted order.
        subset_ratio: Fraction of each folder's sorted file list that is kept;
            the first ``int(len(files) * subset_ratio)`` entries are used.
    """

    def __init__(self, roots, transform=None, color_folders=None, subset_ratio=1):
        self.images = []
        self.labels = []
        self.transform = transform
        self.subset_ratio = subset_ratio

        # A bare string would otherwise be iterated character by character.
        if isinstance(color_folders, str):
            color_folders = [color_folders]

        for label, root in enumerate(roots):
            if color_folders is None:
                # No explicit selection: fall back to all sub-directories of this root.
                folders = sorted(
                    entry for entry in os.listdir(root)
                    if os.path.isdir(os.path.join(root, entry))
                )
            else:
                folders = color_folders

            for color_folder in folders:
                color_path = os.path.join(root, color_folder)
                # Sorting makes the selected subset reproducible across runs.
                file_names = sorted(os.listdir(color_path))
                keep = int(len(file_names) * self.subset_ratio)

                for file_name in file_names[:keep]:
                    self.images.append(os.path.join(color_path, file_name))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of collected image paths."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``."""
        # Image.open is lazy and keeps the file handle open; the context manager
        # closes it as soon as convert() has produced an in-memory copy.
        with Image.open(self.images[idx]) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, self.labels[idx]


class CLUB(nn.Module):
    """CLUB-style mutual-information upper-bound estimator.

    Two small MLPs map ``x`` to the mean and log-variance of a diagonal
    Gaussian over ``y``. ``forward`` contrasts matched ``(x_i, y_i)`` pairs
    with all ``(x_i, y_j)`` combinations of the batch; ``learning_loss`` is the
    negative log-likelihood used to fit the two MLPs.
    """

    def __init__(self, x_dim, y_dim, hidden_size):
        super().__init__()
        half_width = hidden_size // 2

        def make_head():
            # Both heads share the same architecture: x_dim -> hidden/2 -> y_dim.
            return nn.Sequential(
                nn.Linear(x_dim, half_width),
                nn.LeakyReLU(0.1, inplace=True),
                nn.Dropout(0.2),
                nn.Linear(half_width, y_dim),
            )

        self.p_mu = make_head()
        self.p_logvar = make_head()

    def get_mu_logvar(self, x_samples):
        """Return the predicted mean and the log-variance clamped to [-10, 10]."""
        mean = self.p_mu(x_samples)
        log_variance = torch.clamp(self.p_logvar(x_samples), min=-10, max=10)
        return mean, log_variance

    def forward(self, x_samples, y_samples):
        """Return the scalar bound estimate for the batch (plus a 1e-6 offset)."""
        mu, logvar = self.get_mu_logvar(x_samples)
        variance = logvar.exp()

        # Matched pairs: squared error between the prediction for x_i and y_i.
        matched = -(mu - y_samples) ** 2 / 2. / variance

        # All pairs: prediction for x_i against every y_j, averaged over j.
        all_pairs_gap = (y_samples.unsqueeze(0) - mu.unsqueeze(1)) ** 2
        unmatched = -all_pairs_gap.mean(dim=1) / 2. / variance

        bound = (matched.sum(dim=-1) - unmatched.sum(dim=-1)).mean()
        return bound + 1e-6

    def loglikeli(self, x_samples, y_samples):
        """Return the batch-mean Gaussian log-likelihood term (plus a 1e-6 offset)."""
        mu, logvar = self.get_mu_logvar(x_samples)
        per_dimension = -(mu - y_samples) ** 2 / logvar.exp() - logvar
        return per_dimension.sum(dim=1).mean(dim=0) + 1e-6

    def learning_loss(self, x_samples, y_samples):
        """Return the negated ``loglikeli`` value, i.e. the loss to minimise."""
        return -self.loglikeli(x_samples, y_samples)


def get_gmm_param(gamma, z, device):
    """Estimate mixture weights, means and covariances from soft assignments.

    The broadcasting below is written for ``gamma`` laid out as
    (samples, components) and ``z`` as (samples, features).

    Returns:
        Tuple ``(ceta, mean, cov)`` moved to ``device``: per-component weight,
        per-component mean vector and per-component covariance matrix.
    """
    n_samples = gamma.shape[0]
    # Total responsibility each component received over the batch.
    component_mass = torch.sum(gamma, dim=0)

    ceta = component_mass / n_samples

    weighted_points = gamma.unsqueeze(-1) * z.unsqueeze(1)
    mean = torch.sum(weighted_points, dim=0) / component_mass.unsqueeze(-1)

    # Offset of every sample from every component mean.
    centered = z.unsqueeze(1) - mean.unsqueeze(0)
    weighted_outer = gamma.unsqueeze(-1).unsqueeze(-1) * centered.unsqueeze(-1) * centered.unsqueeze(-2)
    cov = torch.sum(weighted_outer, dim=0) / component_mass.unsqueeze(-1).unsqueeze(-1)

    return ceta.to(device), mean.to(device), cov.to(device)


def sample_energy(ceta, mean, cov, z_background, n_gmm, bs, device):
    """Return the per-sample energy ``-log(sum_k ceta[k] * exp(-0.5 * q_k))``.

    For each component ``k``, ``q_k`` is the quadratic form of the
    unit-normalised offset ``z - mean[k]`` with the inverse of
    ``cov[k] + 1e-6 * I``, clamped to ``[-0.1, 0.1]``.

    Args:
        ceta: Per-component mixture weights.
        mean: Per-component mean vectors.
        cov: Per-component covariance matrices.
        z_background: Batch of latent vectors (``bs`` rows).
        n_gmm: Number of mixture components to accumulate.
        bs: Batch size; must match the first dimension of ``z_background``.
        device: Device on which the accumulator and jitter matrix are created.
    """
    mixture_sum = torch.zeros(bs, device=device)
    # Small diagonal jitter so the covariance stays invertible.
    jitter = torch.eye(mean.shape[1], device=device) * 1e-6

    for k in range(n_gmm):
        offset = z_background - mean[k].unsqueeze(0)
        offset_length = torch.norm(offset, dim=1, keepdim=True)
        unit_offset = offset / (offset_length + 1e-6)

        precision = torch.linalg.inv(cov[k] + jitter).unsqueeze(0).expand(bs, -1, -1)
        projected = torch.bmm(precision, unit_offset.unsqueeze(2))
        quad_form = torch.bmm(unit_offset.unsqueeze(1), projected).squeeze()
        quad_form = torch.clamp(quad_form, min=-0.1, max=0.1)

        mixture_sum += torch.exp(-0.5 * quad_form) * ceta[k]

    return -torch.log(mixture_sum)


def compute_loss(x, reconstructed_x, z_background, gamma, device):
    """Compute the three loss ingredients for one batch.

    Returns:
        Tuple ``(reconstruction_loss, energy_sum_background, p)``:
        the reconstruction error of ``x``, the batch-mean sample energy of
        ``z_background`` under the mixture fitted from ``gamma``, and the sum
        of the reciprocal covariance diagonals over all components.
    """
    batch_size, n_components = gamma.shape[0], gamma.shape[1]

    reconstruction_loss = reconstruct_error(x, reconstructed_x)

    ceta, mean, cov = get_gmm_param(gamma, z_background, device)
    per_sample_energy = sample_energy(ceta, mean, cov, z_background, n_components, batch_size, device)
    energy_sum_background = per_sample_energy.mean()

    # Accumulate 1 / diag(cov_k) over every component.
    inverse_diag_total = torch.tensor(0.0).to(device)
    for k in range(n_components):
        component_diag = torch.diagonal(cov[k], 0)
        inverse_diag_total += torch.sum(1 / component_diag)

    return reconstruction_loss, energy_sum_background, inverse_diag_total


def reconstruct_error(x, reconstructed_x):
    """Return the mean squared error between ``x`` and ``reconstructed_x`` as a scalar tensor."""
    return F.mse_loss(x, reconstructed_x, reduction='mean')


def reparameterize(mu, logvar):
    """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps`` sampled from a standard normal."""
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma


class Estimator(nn.Module):
    """MLP that maps a latent vector to soft mixture-component assignments.

    Args:
        z_dims: Size of the incoming latent vector.
        n_components: Number of mixture components (size of the output).
    """

    def __init__(self, z_dims, n_components):
        super().__init__()
        self.fc1 = nn.Linear(z_dims, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.fc4 = nn.Linear(10, n_components)

    def forward(self, z):
        """Return ``(logits, gamma)`` where ``gamma = softmax(logits, dim=1)``."""
        z = F.relu(self.fc1(z))
        z = F.relu(self.fc2(z))
        z = F.relu(self.fc3(z))
        # F.dropout defaults to training=True, which would keep dropping (and
        # rescaling) activations under model.eval(); tie it to the module mode
        # so inference is deterministic.
        z = F.dropout(torch.tanh(self.fc4(z)), p=0.2, training=self.training)
        gamma = F.softmax(z, dim=1)
        return z, gamma


class Encoder(nn.Module):
    """Convolutional encoder producing two sampled latents (color and digit).

    Four stride-2 convolutions feed a linear layer expecting 512 flattened
    features, so the spatial size must have collapsed to 1x1 by then.
    """

    def __init__(self, zdims=128):
        super().__init__()
        self.zdims = zdims

        # Channel widths of the conv trunk; each stage is Conv -> BatchNorm -> LeakyReLU.
        widths = (3, 64, 128, 256, 512)
        trunk = []
        for in_channels, out_channels in zip(widths[:-1], widths[1:]):
            trunk.append(nn.Conv2d(in_channels, out_channels, 4, 2, 1))
            trunk.append(nn.BatchNorm2d(out_channels))
            trunk.append(nn.LeakyReLU(0.1, inplace=True))
        trunk.append(nn.Flatten())
        trunk.append(nn.Linear(512 * 1 * 1, self.zdims))
        self.model = nn.Sequential(*trunk)

        # Separate mean / log-variance heads for each of the two latents.
        self.fc_color = nn.Linear(zdims, zdims)
        self.fc_logvar_color = nn.Linear(zdims, zdims)
        self.fc_digit = nn.Linear(zdims, zdims)
        self.fc_logvar_digit = nn.Linear(zdims, zdims)

    def forward(self, x):
        """Return ``(color_latent, digit_latent)``, each sampled via ``reparameterize``."""
        features = self.model(x)
        color_latent = reparameterize(self.fc_color(features), self.fc_logvar_color(features))
        digit_latent = reparameterize(self.fc_digit(features), self.fc_logvar_digit(features))
        return color_latent, digit_latent


class Decoder(nn.Module):
    """Two parallel transposed-convolution decoders whose outputs are summed.

    One branch decodes the color latent, the other the digit latent; both end
    in a sigmoid and share the same architecture.
    """

    def __init__(self, zdims=128):
        super().__init__()
        self.color_decoder = self._build_branch(zdims)
        self.digit_decoder = self._build_branch(zdims)

    @staticmethod
    def _build_branch(zdims):
        """Create one latent-to-image branch (Linear, then four ConvTranspose2d stages)."""
        return nn.Sequential(
            nn.Linear(zdims, 512),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Unflatten(1, (512, 1, 1)),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, output_padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1, inplace=True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1, inplace=True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1, inplace=True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, z_color, z_digit):
        """Return ``(recon_color + recon_digit, recon_color, recon_digit)``."""
        recon_color = self.color_decoder(z_color)
        recon_digit = self.digit_decoder(z_digit)
        return recon_color + recon_digit, recon_color, recon_digit


class MMD_VAE(nn.Module):
    """Container model wiring encoder, decoder, CLUB module and mixture estimator.

    The encoder's first output is used as the background latent and its second
    as the subject latent; the estimator sees only the background latent.
    """

    def __init__(self, encoder, decoder, club, zdims=128, n_components=3):
        super().__init__()
        self.zdims = zdims
        self.encoder = encoder
        self.decoder = decoder
        # Stored as a sub-module, so its parameters are part of self.parameters().
        self.club = club
        self.estimator = Estimator(zdims, n_components)

    def forward(self, x):
        """Return ``(recon_combined, recon_color, recon_digit, gamma, z_background, z_subject)``."""
        latent_background, latent_subject = self.encoder(x)
        combined, color_part, digit_part = self.decoder(latent_background, latent_subject)
        _logits, gamma = self.estimator(latent_background)
        return combined, color_part, digit_part, gamma, latent_background, latent_subject


def weights_init(m):
    """Kaiming-normal initialise Conv2d / Linear weights and zero their biases.

    Intended for ``module.apply``; any other layer type is left untouched.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return

    nn.init.kaiming_normal_(m.weight)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)


def train(device, data_loader, test_loader, model, mi_model, n_epochs, lr, wd, lam1, lam2, zdims, num, color,
          combination, eval_every=1):
    """Train the VAE together with its mutual-information (MI) estimator.

    For every batch the main model takes one step on
    ``reconstruction + lam1 * energy + lam2 * MI estimate`` and the estimator
    then takes ten likelihood steps on freshly sampled latents. After each
    evaluated epoch a row is appended to
    ``./result_model/mnist_mmd_vae_{num}_performance.csv``; the final weights
    are saved to ``./models/model_{num}.pth``.

    Args:
        device: Device the batches are moved to.
        data_loader: Training loader yielding ``(images, labels)``.
        test_loader: Loader passed to ``test`` for evaluation.
        model: Module whose forward returns
            ``(recon, _, _, gamma, z_background, z_subject)``.
        mi_model: MI estimator exposing ``forward(x, y)`` and
            ``learning_loss(x, y)``; it may be a sub-module of ``model``.
        n_epochs: Number of training epochs.
        lr: Initial learning rate of both optimizers.
        wd: Weight decay of both optimizers.
        lam1: Weight of the energy term.
        lam2: Weight of the MI term.
        zdims: Latent size; only recorded in the CSV.
        num: Run identifier used in the output file names.
        color: Unused; kept for interface compatibility.
        combination: Classes treated as normal, forwarded to ``test``.
        eval_every: Evaluate every this many epochs; ``0``/``None`` disables
            evaluation (the returned metrics then stay at 0.0).

    Returns:
        Tuple ``(max_auc, max_prc)``: best AUC and best average precision seen.
    """
    result_dir = './result_model'
    model_dir = './models'
    # Create the output folders up front so a missing directory cannot fail the
    # run only when the weights are saved after training.
    os.makedirs(result_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    # Only the MI estimator is re-initialised; the other layers keep their
    # construction-time initialisation.
    mi_model.apply(weights_init)

    # mi_model may be registered inside `model`. Its parameters must belong to
    # exactly one optimizer, otherwise two AdamW instances (each with its own
    # moments and weight decay) fight over the same weights.
    mi_param_ids = {id(p) for p in mi_model.parameters()}
    main_params = [p for p in model.parameters() if id(p) not in mi_param_ids]

    optimizer = optim.AdamW(main_params, lr=lr, weight_decay=wd)
    optimizer_mi = optim.AdamW(mi_model.parameters(), lr=lr, weight_decay=wd)
    # The deprecated `verbose` flag is intentionally not passed.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20)
    scheduler_mi = ReduceLROnPlateau(optimizer_mi, mode='min', factor=0.5, patience=20)

    # Track best AUC and PRC
    max_auc = 0.0
    max_prc = 0.0

    csv_path = os.path.join(result_dir, f'mnist_mmd_vae_{num}_performance.csv')
    with open(csv_path, 'w', newline='') as csvfile:
        fieldnames = ['Epoch', 'Best AUC', 'Best PRC', 'AUC Background', 'PRC Background', 'LR', 'Lambda1', 'Lambda2',
                      'Z_dims', 'WD']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        for epoch in range(n_epochs):
            # test() switches the model to eval mode, so re-enable training mode every epoch.
            model.train()
            total_loss_epoch = 0.0
            total_mi_loss_epoch = 0.0
            n_batches = 0

            for data, _ in data_loader:
                data = data.to(device)

                # ---- main model step -------------------------------------
                mi_model.eval()
                optimizer.zero_grad()
                recon_combined, _, _, gamma, z_background, z_subject = model(data)
                # Penalise the MI upper bound (forward), using the same
                # (z_background -> z_subject) direction the estimator is fitted
                # on below; learning_loss is the estimator's own objective.
                mi_estimate = mi_model(z_background, z_subject)
                reconstruction_loss, energy_sum_background, _ = compute_loss(data, recon_combined, z_background, gamma,
                                                                             device)
                loss = reconstruction_loss + lam1 * energy_sum_background + lam2 * mi_estimate
                loss.backward()
                nn.utils.clip_grad_norm_(main_params, max_norm=1.0)
                optimizer.step()

                total_loss_epoch += loss.item()
                n_batches += 1

                # ---- MI estimator steps ----------------------------------
                mi_model.train()
                for _ in range(10):
                    optimizer_mi.zero_grad()
                    # Latents are inputs only here; skip building the VAE graph.
                    with torch.no_grad():
                        _, _, _, _, z_background, z_subject = model(data)
                    mi_loss = mi_model.learning_loss(z_background, z_subject)
                    mi_loss.backward()
                    nn.utils.clip_grad_norm_(mi_model.parameters(), max_norm=1.0)
                    optimizer_mi.step()

                total_mi_loss_epoch += mi_loss.item()

            # Step on plain floats accumulated over the epoch; skip when the
            # loader produced no batch at all.
            if n_batches:
                scheduler.step(total_loss_epoch)
                scheduler_mi.step(total_mi_loss_epoch)

            if eval_every and (epoch + 1) % eval_every == 0:
                auc, prc, auc_background, prc_background, _ = test(model, device, test_loader, data_loader,
                                                                   combination, epoch)
                max_auc = max(max_auc, auc)
                max_prc = max(max_prc, prc)
                writer.writerow({
                    'Epoch': epoch,
                    'Best AUC': max_auc,
                    'Best PRC': max_prc,
                    'AUC Background': auc_background,
                    'PRC Background': prc_background,
                    'LR': lr,
                    'Lambda1': lam1,
                    'Lambda2': lam2,
                    'Z_dims': zdims,
                    'WD': wd,
                })
                # Keep partial results on disk if a later epoch fails.
                csvfile.flush()

    torch.save(model.state_dict(), os.path.join(model_dir, f'model_{num}.pth'))
    return max_auc, max_prc


def test(model, device, test_loader, train_loader, combination, epoch):
    """Evaluate anomaly detection on the subject latent with a kernel density estimate.

    A ``KernelDensity`` model is fitted on the subject latents of the training
    data; test samples are scored by their negative log-density, so a higher
    score means "more anomalous".

    Args:
        model: Module whose forward returns the subject latent as its sixth output.
        device: Device the batches are moved to.
        test_loader: Loader yielding ``(images, class_labels)`` for evaluation.
        train_loader: Loader yielding the data the density model is fitted on.
        combination: Class labels considered normal; everything else is an
            anomaly (label 1).
        epoch: Passed through unchanged as the last element of the result.

    Returns:
        Tuple ``(auc, average_precision, 0, 0, epoch)``. The two zeros are
        placeholders (the caller's CSV has background-metric columns).
    """
    model.eval()
    normal_classes = np.array(combination)
    test_labels_list = []
    test_subject_data = []
    train_subject_data = []

    with torch.no_grad():
        for data, target in test_loader:
            _, _, _, _, _, z_subject = model(data.to(device))
            is_normal = target.cpu().numpy()[:, None] == normal_classes[None, :]
            labels = np.logical_not(is_normal.any(axis=1)).astype(int)
            test_labels_list.extend(labels)
            test_subject_data.append(z_subject.cpu())

        for data, _ in train_loader:
            _, _, _, _, _, z_subject = model(data.to(device))
            train_subject_data.append(z_subject.cpu())

    test_subject_data = torch.cat(test_subject_data, dim=0).numpy()
    train_subject_data = torch.cat(train_subject_data, dim=0).numpy()

    kde_subject = KernelDensity().fit(train_subject_data)
    log_dens_subject = kde_subject.score_samples(test_subject_data)
    # Label 1 marks anomalies, and both metrics expect larger scores for the
    # positive class, so the score is the NEGATIVE log-density. Staying in log
    # space also avoids exp() underflowing to tied zeros in high dimensions.
    anomaly_scores = -log_dens_subject

    auc_kde = roc_auc_score(test_labels_list, anomaly_scores)
    aucprc = average_precision_score(test_labels_list, anomaly_scores)

    return auc_kde, aucprc, 0, 0, epoch


def main():
    """Train one model per digit class on the colored-MNIST folders.

    For every digit ``num`` a fresh model is trained on that digit's blue,
    yellow and white images and evaluated against the green test images of
    all ten digits. A failure for one digit is logged with its traceback and
    the loop moves on to the next digit.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    batch_size = 1024
    # Pinned host memory only helps host-to-GPU copies.
    pin_memory = device.type == 'cuda'
    transform = transforms.Compose([transforms.ToTensor()])

    train_root = './data/MNIST_four_colored/train'
    test_root = './data/MNIST_four_colored/test'
    color = 'green'

    # Hyper-parameters shared by every per-digit run.
    estimator_name = "CLUB"
    lam1 = 0.1
    lam2 = 0.0001
    zdims = 128
    lr = 1e-3
    wd = 0.001
    n_components = 3

    # Explicit registry instead of eval(): only known estimator classes can be built.
    estimators = {"CLUB": CLUB}

    test_root_dirs = [os.path.join(test_root, str(i)) for i in range(10)]
    test_dataset = CustomImageDataset(roots=test_root_dirs, transform=transform, color_folders=[color],
                                      subset_ratio=0.25)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory,
                             num_workers=8)

    for num in range(10):
        # The single digit treated as the normal class for this run.
        combination = (num,)
        train_root_dirs = [os.path.join(train_root, str(i)) for i in combination]
        train_dataset = CustomImageDataset(roots=train_root_dirs, transform=transform,
                                           color_folders=['blue', 'yellow', 'white'], subset_ratio=0.25)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory,
                                  num_workers=8)

        try:
            club = estimators[estimator_name](zdims, zdims, 256).to(device)
            encoder = Encoder(zdims).to(device)
            decoder = Decoder(zdims).to(device)
            model = MMD_VAE(encoder, decoder, club, zdims, n_components).to(device)

            max_auc, max_prc = train(device, train_loader, test_loader, model, club, n_epochs=100, lr=lr,
                                     wd=wd, lam1=lam1, lam2=lam2, zdims=zdims, num=num, color=color,
                                     combination=combination)
        except Exception:
            # Best-effort sweep: keep going with the next digit, but keep the
            # full traceback in the log so the failure can be diagnosed.
            logging.exception("Exception occurred while training class %s", num)
            continue

        logging.info(f"Finished training model for class {num} with AUC {max_auc}, PRC {max_prc}")


# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
